{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import math\n",
    "import logging\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import scipy as sp\n",
    "import sklearn\n",
    "import statsmodels.api as sm\n",
    "from statsmodels.formula.api import ols\n",
    "\n",
    "# Reload edited modules automatically before each cell is executed.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "\n",
    "import seaborn as sns\n",
    "# Configure seaborn in ONE call. A separate `sns.set(rc=...)` re-applies the default\n",
    "# \"notebook\" context and would silently undo an earlier `sns.set_context(\"poster\")`.\n",
    "sns.set(context=\"poster\", style=\"whitegrid\", rc={'figure.figsize': (16, 9.)})\n",
    "\n",
    "import pandas as pd\n",
    "pd.set_option(\"display.max_rows\", 120)\n",
    "pd.set_option(\"display.max_columns\", 120)\n",
    "\n",
    "# Route log records of level INFO and above to stdout so they appear in the cell output.\n",
    "logging.basicConfig(level=logging.INFO, stream=sys.stdout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from justcause import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**PLEASE** save this file right now using the following naming convention: `NUMBER_FOR_SORTING-YOUR_INITIALS-SHORT_DESCRIPTION`, e.g. `1.0-fw-initial-data-exploration`. Use the number to order the file within the directory according to its usage."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Quickstart"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/fwilhelm/.miniconda3/envs/justcause/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
      "  FutureWarning)\n",
      "/Users/fwilhelm/.miniconda3/envs/justcause/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
      "  FutureWarning)\n",
      "/Users/fwilhelm/.miniconda3/envs/justcause/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
      "  FutureWarning)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>pehe_score</th>\n",
       "      <th>mean_absolute</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.918890</td>\n",
       "      <td>0.030689</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.795572</td>\n",
       "      <td>0.133251</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.849018</td>\n",
       "      <td>0.123900</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   pehe_score  mean_absolute\n",
       "0    0.918890       0.030689\n",
       "1    0.795572       0.133251\n",
       "2    0.849018       0.123900"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from justcause.data.sets import load_ihdp\n",
    "from justcause.learners import SLearner\n",
    "from justcause.learners.propensity import estimate_propensities\n",
    "from justcause.metrics import pehe_score, mean_absolute\n",
    "from justcause.evaluation import calc_scores\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "# Fixed seed for the train/test split so the score table is the same on every run.\n",
    "RANDOM_STATE = 42\n",
    "\n",
    "# Evaluate an S-learner with a linear-regression base model on IHDP replications 0, 1 and 2.\n",
    "replications = load_ihdp(select_rep=[0, 1, 2])\n",
    "slearner = SLearner(LinearRegression())\n",
    "metrics = [pehe_score, mean_absolute]\n",
    "scores = []\n",
    "\n",
    "for rep in replications:\n",
    "    # Hold out 20% of this replication for scoring.\n",
    "    train, test = train_test_split(rep, train_size=0.8, random_state=RANDOM_STATE)\n",
    "    # Weight every training unit by the inverse of its estimated propensity.\n",
    "    # NOTE(review): presumably p is P(t=1 | x) - confirm 1/p is also the intended weight for control units.\n",
    "    p = estimate_propensities(train.np.X, train.np.t)\n",
    "    slearner.fit(train.np.X, train.np.t, train.np.y, weights=1/p)\n",
    "    pred_ite = slearner.predict_ite(test.np.X, test.np.t, test.np.y)\n",
    "    # Compare the predicted treatment effects with `test.np.ite` under every metric.\n",
    "    scores.append(calc_scores(test.np.ite, pred_ite, metrics))\n",
    "\n",
    "# One row per replication, one column per metric.\n",
    "pd.DataFrame(scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
